import torch
import numpy as np
from numpy import arange
import math
# from parse import vis_const_prior, vis_FP, vis_FP_1
import torch.nn as nn
from timm.models.layers import drop_path, to_2tuple, trunc_normal_

# Report the NCCL version bundled with this PyTorch build.
# torch.cuda.nccl.version() is only callable when PyTorch was compiled with
# NCCL; on builds without it the call raises and would abort the whole script,
# so probe for NCCL support first and fall back to a plain message.
nccl_version = (
    torch.cuda.nccl.version()
    if torch.distributed.is_available() and torch.distributed.is_nccl_available()
    else None
)
print(
    nccl_version
    if nccl_version is not None
    else "NCCL is not available in this PyTorch build"
)


seq_len = 20

# (seq_len, seq_len) int32 matrix whose ones sit on the first superdiagonal,
# i.e. the main diagonal shifted up by one position; everything else is zero.
a = torch.from_numpy(np.eye(seq_len, k=1, dtype=np.int32))

# Prepend one singleton axis: (1, seq_len, seq_len).
b = a[None]

# Insert a second singleton axis at position 1: (1, 1, seq_len, seq_len).
c = b[:, None]

num_classes = 20

# Two stacked linear stages (23 -> num_classes -> num_classes). A non-positive
# num_classes replaces both stages with pass-through nn.Identity modules.
head = nn.Sequential(
    nn.Linear(23, num_classes) if num_classes > 0 else nn.Identity(),
    nn.Linear(num_classes, num_classes) if num_classes > 0 else nn.Identity(),
)

# Factor applied to every weight and bias after initialisation; 0. zeroes them.
init_scale = 0.

# Initialise and rescale the Linear stages only: nn.Identity has neither a
# weight nor a bias, so touching those attributes on it would raise
# AttributeError when num_classes <= 0.
with torch.no_grad():
    for layer in head:
        if not isinstance(layer, nn.Linear):
            continue
        trunc_normal_(layer.weight, std=.02)
        layer.weight.mul_(init_scale)
        layer.bias.mul_(init_scale)

print("ss")